/*
 * Copyright 2018- The Pixie Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 */

#include "src/stirling/source_connectors/socket_tracer/protocols/amqp/stitcher.h"
#include <utility>
#include <vector>
#include "src/common/base/types.h"
#include "src/common/testing/testing.h"
#include "src/stirling/source_connectors/socket_tracer/protocols/amqp/decode.h"
#include "src/stirling/source_connectors/socket_tracer/protocols/amqp/types_gen.h"

namespace px {
namespace stirling {
namespace protocols {
namespace amqp {

using ::testing::ElementsAre;
using ::testing::IsEmpty;
using ::px::operator<<;
// Convenience wrapper that forwards to PX_ASSIGN_OR with `return ParseState::kInvalid` as the
// third argument (presumably the action taken when `val_or` holds an error -- confirm against
// the PX_ASSIGN_OR definition).
// NOTE(review): not referenced in the visible part of this file -- confirm it is still needed.
#define PX_ASSIGN_OR_RETURN_INVALID(expr, val_or) \
  PX_ASSIGN_OR(expr, val_or, return ParseState::kInvalid)

// Builds a Frame from a raw byte array for use in the tests below.
//
// The bytes are handed to ParseFrame() as a char string view; a parse result other than
// ParseState::kSuccess is reported as a (non-fatal) test failure. The channel of the returned
// frame is then overwritten with `channel`, regardless of what the raw bytes contained.
//
// @param type        Whether the frame is parsed as a request or a response.
// @param raw_packet  Raw frame bytes.
// @param channel     Channel value stored in the returned frame.
// @return The parsed frame with its channel replaced.
template <size_t N>
Frame GenFrame(message_type_t type, const uint8_t (&raw_packet)[N], uint8_t channel) {
  const auto raw_view = CharArrayStringView<uint8_t>(raw_packet);
  auto bytes = CreateStringView<char>(raw_view);

  Frame frame;
  const ParseState status = ParseFrame(type, &bytes, &frame);
  EXPECT_EQ(status, ParseState::kSuccess);

  // Tests choose the channel explicitly, so override whatever was parsed.
  frame.channel = channel;
  return frame;
}

// Raw content-header frame. Read per AMQP 0-9-1 framing, the bytes are: frame type 0x02,
// channel 1, payload size 0x19 (25); class-id 0x3c (60), weight 0, body size 10, property
// flags 0x8000, short string "text/plain"; frame-end 0xce.
constexpr uint8_t content_header_packet[] = {0x02, 0x00, 0x01, 0x00, 0x00, 0x00, 0x19, 0x00, 0x3c,
                                             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                                             0x0a, 0x80, 0x00, 0x0a, 0x74, 0x65, 0x78, 0x74, 0x2f,
                                             0x70, 0x6c, 0x61, 0x69, 0x6e, 0xce};

// Raw content-body frame. Read per AMQP 0-9-1 framing, the bytes are: frame type 0x03,
// channel 1, payload size 0x0a (10); 10 payload bytes (ASCII "wZt5S4zjhD"); frame-end 0xce.
constexpr uint8_t content_body_packet[] = {0x03, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0a, 0x77, 0x5a,
                                           0x74, 0x35, 0x53, 0x34, 0x7a, 0x6a, 0x68, 0x44, 0xce};

// Raw method frame. Read per AMQP 0-9-1 framing, the bytes are: frame type 0x01, channel 1,
// payload size 0x0e (14); class-id 0x3c (60), method-id 0x28 (40, i.e. Basic.Publish),
// 2 reserved bytes, empty exchange name, routing key "hello", flag byte 0; frame-end 0xce.
constexpr uint8_t basic_publish_packet[] = {0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0e, 0x00,
                                            0x3c, 0x00, 0x28, 0x00, 0x00, 0x00, 0x05, 0x68,
                                            0x65, 0x6c, 0x6c, 0x6f, 0x00, 0xce};

// Raw method frame. Read per AMQP 0-9-1 framing, the bytes are: frame type 0x01, channel 1,
// payload size 0x40 (64); class-id 0x3c (60), method-id 0x3c (60, i.e. Basic.Deliver),
// consumer tag "ctag-/tmp/go-build458173189/b001/exe/recv-1" (43 bytes), 8-byte delivery tag
// 0x12 (18), redelivered byte 0, empty exchange name, routing key "hello"; frame-end 0xce.
constexpr uint8_t basic_deliver_packet_response[] = {
    0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x40, 0x00, 0x3c, 0x00, 0x3c, 0x2b, 0x63, 0x74, 0x61,
    0x67, 0x2d, 0x2f, 0x74, 0x6d, 0x70, 0x2f, 0x67, 0x6f, 0x2d, 0x62, 0x75, 0x69, 0x6c, 0x64,
    0x34, 0x35, 0x38, 0x31, 0x37, 0x33, 0x31, 0x38, 0x39, 0x2f, 0x62, 0x30, 0x30, 0x31, 0x2f,
    0x65, 0x78, 0x65, 0x2f, 0x72, 0x65, 0x63, 0x76, 0x2d, 0x31, 0x00, 0x00, 0x00, 0x00, 0x00,
    0x00, 0x00, 0x12, 0x00, 0x00, 0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0xce};

// Raw method frame. Read per AMQP 0-9-1 framing, the bytes are: frame type 0x01, channel 0,
// payload size 0x0c (12); class-id 0x0a (10), method-id 0x1e (30, i.e. Connection.Tune),
// channel-max 0x07ff (2047), frame-max 0x00020000 (131072), heartbeat 0x3c (60);
// frame-end 0xce.
constexpr uint8_t connection_tune_packet[] = {0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0c,
                                              0x00, 0x0a, 0x00, 0x1e, 0x07, 0xff, 0x00,
                                              0x02, 0x00, 0x00, 0x00, 0x3c, 0xce};

// Raw method frame. Read per AMQP 0-9-1 framing, the bytes are: frame type 0x01, channel 0,
// payload size 0x0c (12); class-id 0x0a (10), method-id 0x1f (31, i.e. Connection.Tune-Ok),
// channel-max 0x07ff (2047), frame-max 0x00020000 (131072), heartbeat 0x0a (10);
// frame-end 0xce.
constexpr uint8_t connect_tune_ok_packet[] = {0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0c,
                                              0x00, 0x0a, 0x00, 0x1f, 0x07, 0xff, 0x00,
                                              0x02, 0x00, 0x00, 0x00, 0x0a, 0xce};

// Raw method frame. Read per AMQP 0-9-1 framing, the bytes are: frame type 0x01, channel 0,
// payload size 0x08; class-id 0x0a (10), method-id 0x28 (40, i.e. Connection.Open),
// virtual host "/", followed by two zero bytes (reserved fields); frame-end 0xce.
constexpr uint8_t connect_open_packet[] = {0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00,
                                           0x0a, 0x00, 0x28, 0x01, 0x2f, 0x00, 0x00, 0xce};

// Raw method frame. Read per AMQP 0-9-1 framing, the bytes are: frame type 0x01, channel 0,
// payload size 0x05; class-id 0x0a (10), method-id 0x29 (41, i.e. Connection.Open-Ok),
// one zero byte (empty reserved short string); frame-end 0xce.
constexpr uint8_t connect_open_ok_packet[] = {0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05,
                                              0x00, 0x0a, 0x00, 0x29, 0x00, 0xce};

// Raw method frame. Read per AMQP 0-9-1 framing, the bytes are: frame type 0x01, channel 0,
// payload size 0x200 (512); class-id 0x0a (10), method-id 0x0a (10, i.e. Connection.Start),
// version bytes 0 and 9, a server-properties table (capabilities, cluster_name, copyright,
// information, platform, product "RabbitMQ", version "3.9.13"), the long string
// "PLAIN AMQPLAIN" and the long string "en_US"; frame-end 0xce.
constexpr uint8_t connection_start_packet[] = {
    0x01, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x0a, 0x00, 0x0a, 0x00, 0x09, 0x00, 0x00, 0x01,
    0xdb, 0x0c, 0x63, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x46, 0x00,
    0x00, 0x00, 0xc7, 0x12, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x73, 0x68, 0x65, 0x72, 0x5f, 0x63, 0x6f,
    0x6e, 0x66, 0x69, 0x72, 0x6d, 0x73, 0x74, 0x01, 0x1a, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67,
    0x65, 0x5f, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x5f, 0x62, 0x69, 0x6e, 0x64, 0x69,
    0x6e, 0x67, 0x73, 0x74, 0x01, 0x0a, 0x62, 0x61, 0x73, 0x69, 0x63, 0x2e, 0x6e, 0x61, 0x63, 0x6b,
    0x74, 0x01, 0x16, 0x63, 0x6f, 0x6e, 0x73, 0x75, 0x6d, 0x65, 0x72, 0x5f, 0x63, 0x61, 0x6e, 0x63,
    0x65, 0x6c, 0x5f, 0x6e, 0x6f, 0x74, 0x69, 0x66, 0x79, 0x74, 0x01, 0x12, 0x63, 0x6f, 0x6e, 0x6e,
    0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x65, 0x64, 0x74, 0x01,
    0x13, 0x63, 0x6f, 0x6e, 0x73, 0x75, 0x6d, 0x65, 0x72, 0x5f, 0x70, 0x72, 0x69, 0x6f, 0x72, 0x69,
    0x74, 0x69, 0x65, 0x73, 0x74, 0x01, 0x1c, 0x61, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63,
    0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x66, 0x61, 0x69, 0x6c, 0x75, 0x72, 0x65, 0x5f, 0x63, 0x6c,
    0x6f, 0x73, 0x65, 0x74, 0x01, 0x10, 0x70, 0x65, 0x72, 0x5f, 0x63, 0x6f, 0x6e, 0x73, 0x75, 0x6d,
    0x65, 0x72, 0x5f, 0x71, 0x6f, 0x73, 0x74, 0x01, 0x0f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x5f,
    0x72, 0x65, 0x70, 0x6c, 0x79, 0x5f, 0x74, 0x6f, 0x74, 0x01, 0x0c, 0x63, 0x6c, 0x75, 0x73, 0x74,
    0x65, 0x72, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x53, 0x00, 0x00, 0x00, 0x19, 0x72, 0x61, 0x62, 0x62,
    0x69, 0x74, 0x40, 0x62, 0x6f, 0x6d, 0x62, 0x65, 0x2e, 0x70, 0x69, 0x78, 0x69, 0x65, 0x6c, 0x61,
    0x62, 0x73, 0x2e, 0x61, 0x69, 0x09, 0x63, 0x6f, 0x70, 0x79, 0x72, 0x69, 0x67, 0x68, 0x74, 0x53,
    0x00, 0x00, 0x00, 0x37, 0x43, 0x6f, 0x70, 0x79, 0x72, 0x69, 0x67, 0x68, 0x74, 0x20, 0x28, 0x63,
    0x29, 0x20, 0x32, 0x30, 0x30, 0x37, 0x2d, 0x32, 0x30, 0x32, 0x32, 0x20, 0x56, 0x4d, 0x77, 0x61,
    0x72, 0x65, 0x2c, 0x20, 0x49, 0x6e, 0x63, 0x2e, 0x20, 0x6f, 0x72, 0x20, 0x69, 0x74, 0x73, 0x20,
    0x61, 0x66, 0x66, 0x69, 0x6c, 0x69, 0x61, 0x74, 0x65, 0x73, 0x2e, 0x0b, 0x69, 0x6e, 0x66, 0x6f,
    0x72, 0x6d, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x53, 0x00, 0x00, 0x00, 0x39, 0x4c, 0x69, 0x63, 0x65,
    0x6e, 0x73, 0x65, 0x64, 0x20, 0x75, 0x6e, 0x64, 0x65, 0x72, 0x20, 0x74, 0x68, 0x65, 0x20, 0x4d,
    0x50, 0x4c, 0x20, 0x32, 0x2e, 0x30, 0x2e, 0x20, 0x57, 0x65, 0x62, 0x73, 0x69, 0x74, 0x65, 0x3a,
    0x20, 0x68, 0x74, 0x74, 0x70, 0x73, 0x3a, 0x2f, 0x2f, 0x72, 0x61, 0x62, 0x62, 0x69, 0x74, 0x6d,
    0x71, 0x2e, 0x63, 0x6f, 0x6d, 0x08, 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x53, 0x00,
    0x00, 0x00, 0x11, 0x45, 0x72, 0x6c, 0x61, 0x6e, 0x67, 0x2f, 0x4f, 0x54, 0x50, 0x20, 0x32, 0x34,
    0x2e, 0x32, 0x2e, 0x31, 0x07, 0x70, 0x72, 0x6f, 0x64, 0x75, 0x63, 0x74, 0x53, 0x00, 0x00, 0x00,
    0x08, 0x52, 0x61, 0x62, 0x62, 0x69, 0x74, 0x4d, 0x51, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f,
    0x6e, 0x53, 0x00, 0x00, 0x00, 0x06, 0x33, 0x2e, 0x39, 0x2e, 0x31, 0x33, 0x00, 0x00, 0x00, 0x0e,
    0x50, 0x4c, 0x41, 0x49, 0x4e, 0x20, 0x41, 0x4d, 0x51, 0x50, 0x4c, 0x41, 0x49, 0x4e, 0x00, 0x00,
    0x00, 0x05, 0x65, 0x6e, 0x5f, 0x55, 0x53, 0xce};

// Raw method frame. Read per AMQP 0-9-1 framing, the bytes are: frame type 0x01, channel 0,
// payload size 0xa1 (161); class-id 0x0a (10), method-id 0x0b (11, i.e. Connection.Start-Ok),
// a client-properties table (product "https://github.com/streadway/amqp", version,
// capabilities), the short string "PLAIN", a 12-byte long string "\0guest\0guest" and the
// short string "en_US"; frame-end 0xce.
constexpr uint8_t connection_start_ok_packet[] = {
    0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0xa1, 0x00, 0x0a, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x7d, 0x07,
    0x70, 0x72, 0x6f, 0x64, 0x75, 0x63, 0x74, 0x53, 0x00, 0x00, 0x00, 0x21, 0x68, 0x74, 0x74, 0x70,
    0x73, 0x3a, 0x2f, 0x2f, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x73,
    0x74, 0x72, 0x65, 0x61, 0x64, 0x77, 0x61, 0x79, 0x2f, 0x61, 0x6d, 0x71, 0x70, 0x07, 0x76, 0x65,
    0x72, 0x73, 0x69, 0x6f, 0x6e, 0x53, 0x00, 0x00, 0x00, 0x02, 0xce, 0xb2, 0x0c, 0x63, 0x61, 0x70,
    0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x46, 0x00, 0x00, 0x00, 0x2e, 0x12, 0x63,
    0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x65,
    0x64, 0x74, 0x01, 0x16, 0x63, 0x6f, 0x6e, 0x73, 0x75, 0x6d, 0x65, 0x72, 0x5f, 0x63, 0x61, 0x6e,
    0x63, 0x65, 0x6c, 0x5f, 0x6e, 0x6f, 0x74, 0x69, 0x66, 0x79, 0x74, 0x01, 0x05, 0x50, 0x4c, 0x41,
    0x49, 0x4e, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x67, 0x75, 0x65, 0x73, 0x74, 0x00, 0x67, 0x75, 0x65,
    0x73, 0x74, 0x05, 0x65, 0x6e, 0x5f, 0x55, 0x53, 0xce};

// Queues Connection-class method frames on several channels, stitches them, and checks how
// many records are produced and how many frames remain queued on each side.
TEST(AMQPFrameDecoderTest, BasicSyncMatching) {
  std::deque<Frame> req_packets;
  std::deque<Frame> resp_packets;

  // Local helpers: parse a raw packet and append it to the request / response queue.
  auto add_req = [&req_packets](const auto& raw, uint8_t channel) {
    req_packets.push_back(GenFrame(message_type_t::kRequest, raw, channel));
  };
  auto add_resp = [&resp_packets](const auto& raw, uint8_t channel) {
    resp_packets.push_back(GenFrame(message_type_t::kResponse, raw, channel));
  };

  // Stitching two empty queues is expected to yield no records and no errors.
  RecordsWithErrorCount<Record> result = StitchFrames(&req_packets, &resp_packets);
  EXPECT_TRUE(resp_packets.empty());
  EXPECT_TRUE(req_packets.empty());
  EXPECT_EQ(result.error_count, 0);
  EXPECT_EQ(result.records.size(), 0);

  // Requests on channel 1.
  add_req(connection_start_ok_packet, 1);
  add_req(connection_tune_packet, 1);
  add_req(connect_open_packet, 1);
  add_req(connection_tune_packet, 1);

  // Requests on channel 2.
  add_req(connection_tune_packet, 2);
  add_req(connection_tune_packet, 2);  // Expected to be left behind.
  add_req(connect_open_packet, 2);

  // Request on a channel that never receives a response.
  add_req(connect_open_packet, 6);  // Expected to be left behind.

  // Responses on channel 1.
  add_resp(connection_start_packet, 1);
  add_resp(connect_tune_ok_packet, 1);
  add_resp(connect_open_ok_packet, 1);
  add_resp(connect_tune_ok_packet, 1);

  // Responses on channel 2.
  add_resp(connect_tune_ok_packet, 2);
  add_resp(connect_open_ok_packet, 2);

  // Extraneous responses should stay queued, but should not count as errors.
  add_resp(connect_open_ok_packet, 7);  // Expected to be left behind.
  add_resp(connect_open_ok_packet, 8);  // Expected to be left behind.

  result = StitchFrames(&req_packets, &resp_packets);

  // 8 requests and 8 responses were queued: 6 records are expected, with 2 frames remaining
  // on each side.
  EXPECT_EQ(result.error_count, 0);
  EXPECT_EQ(result.records.size(), 6);

  EXPECT_EQ(req_packets.size(), 2);
  EXPECT_EQ(resp_packets.size(), 2);
}

// Queues content-header, content-body, Basic.Publish and Basic.Deliver frames on both sides and
// checks that stitching turns every queued frame into a record and drains both queues.
TEST(AMQPFrameDecoderTest, AsyncMatching) {
  std::deque<Frame> req_packets;
  std::deque<Frame> resp_packets;

  // Local helpers: parse a raw packet and append it to the request / response queue.
  auto add_req = [&req_packets](const auto& raw, uint8_t channel) {
    req_packets.push_back(GenFrame(message_type_t::kRequest, raw, channel));
  };
  auto add_resp = [&resp_packets](const auto& raw, uint8_t channel) {
    resp_packets.push_back(GenFrame(message_type_t::kResponse, raw, channel));
  };

  // Stitching two empty queues is expected to yield no records and no errors.
  RecordsWithErrorCount<Record> result = StitchFrames(&req_packets, &resp_packets);
  EXPECT_TRUE(resp_packets.empty());
  EXPECT_TRUE(req_packets.empty());
  EXPECT_EQ(result.error_count, 0);
  EXPECT_EQ(result.records.size(), 0);

  // Request-side frames.
  add_req(content_header_packet, 0);
  add_req(basic_publish_packet, 0);
  add_req(basic_deliver_packet_response, 0);
  add_req(content_body_packet, 2);

  // Response-side frames. In reality only deliver would be a response; for testing purposes
  // all the types are tried.
  add_resp(content_header_packet, 5);
  add_resp(basic_deliver_packet_response, 2);
  add_resp(content_body_packet, 2);
  add_resp(basic_deliver_packet_response, 3);

  result = StitchFrames(&req_packets, &resp_packets);

  // 4 requests and 4 responses were queued: 8 records are expected and nothing should remain
  // in either queue.
  EXPECT_EQ(result.error_count, 0);
  EXPECT_EQ(result.records.size(), 8);
  EXPECT_EQ(req_packets.size(), 0);
  EXPECT_EQ(resp_packets.size(), 0);
}

}  // namespace amqp
}  // namespace protocols
}  // namespace stirling
}  // namespace px
